
[mlir] Add bubbling patterns for non-intersecting reshapes #94637

Closed
wants to merge 1 commit

Conversation

@Max191 (Contributor) commented on Jun 6, 2024

This PR adds fusion by expansion patterns to push a tensor.expand_shape up through a tensor.collapse_shape with non-intersecting reassociations. Sometimes parallel collapse_shape ops like this can block propagation of expand_shape ops, so this allows them to pass through each other.
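
For illustration, here is a minimal sketch of the rewrite with hypothetical static shapes (not taken from the patch). The collapse_shape only merges dims 0 and 1, while the expand_shape only splits the trailing dim, so the two ops touch disjoint reassociation groups:

  %collapse = tensor.collapse_shape %arg0 [[0, 1], [2]]
      : tensor<2x3x8xf32> into tensor<6x8xf32>
  %expand = tensor.expand_shape %collapse [[0], [1, 2]] output_shape [6, 2, 4]
      : tensor<6x8xf32> into tensor<6x2x4xf32>

After bubbling, the expand_shape is applied directly to the original operand and the collapse_shape follows, producing the same result type:

  %expand = tensor.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [2, 3, 2, 4]
      : tensor<2x3x8xf32> into tensor<2x3x2x4xf32>
  %collapse = tensor.collapse_shape %expand [[0, 1], [2], [3]]
      : tensor<2x3x2x4xf32> into tensor<6x2x4xf32>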

@Max191 (Contributor, Author) commented on Jun 6, 2024

This is based on #94631. Please only review the last commit.

@llvmbot (Collaborator) commented on Jun 6, 2024

@llvm/pr-subscribers-mlir-tensor
@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: None (Max191)

Changes

This PR adds fusion by expansion patterns to push a tensor.expand_shape up through a tensor.collapse_shape with non-intersecting reassociations. Sometimes parallel collapse_shape ops like this can block propagation of expand_shape ops, so this allows them to pass through each other.


Full diff: https://github.com/llvm/llvm-project/pull/94637.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h (+42-10)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (+71)
  • (modified) mlir/test/Dialect/Linalg/reshape_fusion.mlir (+34)
  • (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+70-2)
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index e8f6edc3f133e..96f0f7bf1aa49 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -85,21 +85,51 @@ bool isReassociationValid(ArrayRef<AffineMap> reassociation,
 template <typename ReshapeOpTy, typename InverseReshapeOpTy>
 static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
                                   ArrayRef<Attribute> operands) {
-
+  // Fold identity reshape.
   if (reshapeOp.getSrcType() == reshapeOp.getType())
     return reshapeOp.getSrc();
 
-  // Fold producer-consumer reshape ops where the operand type of the
-  // producer is same as the return type of the consumer.
-  auto reshapeSrcOp =
-      reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
-  if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
-    return reshapeSrcOp.getSrc();
-
   // Reshape of a constant can be replaced with a new constant.
   if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front()))
     return elements.reshape(cast<ShapedType>(reshapeOp.getResult().getType()));
 
+  // Fold if the producer reshape source has the same shape with at most 1
+  // dynamic dimension.
+  auto reshapeSrcOp =
+      reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
+  if (!reshapeSrcOp)
+    return nullptr;
+  auto srcType = reshapeSrcOp.getSrcType();
+  auto resultType = reshapeOp.getResultType();
+  if (srcType != resultType)
+    return nullptr;
+
+  if (llvm::count_if(srcType.getShape(), ShapedType::isDynamic) < 2) {
+    return reshapeSrcOp.getSrc();
+  }
+
+  // Fold producer-consumer reshape ops when they are perfect inverses of each
+  // other:
+  //   1) Reassociation indices are equivalent.
+  //   2) Boundary types are equivalent.
+  //   3) No reassociations have more than 1 dynamic dimension, and reassociated
+  //      shapes are equal for each reassociation.
+  auto reassociations = reshapeOp.getReassociationIndices();
+  if (reassociations != reshapeSrcOp.getReassociationIndices())
+    return nullptr;
+  // If the reshapes are expanding and then collapsing, the ops can be folded
+  // despite multiple dynamic dimensions.
+  if (srcType.getRank() < reshapeSrcOp.getResultType().getRank())
+    return reshapeSrcOp.getSrc();
+  ArrayRef<int64_t> expandedSrcShape = srcType.getShape();
+  ArrayRef<int64_t> expandedResultShape = resultType.getShape();
+  if (llvm::all_of(reassociations, [&](auto reInd) {
+        ArrayRef<int64_t> srcSlice =
+            expandedSrcShape.slice(reInd.front(), reInd.size());
+        return llvm::count_if(srcSlice, ShapedType::isDynamic) < 2;
+      })) {
+    return reshapeSrcOp.getSrc();
+  }
   return nullptr;
 }
 
@@ -360,10 +390,12 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
           resultShape.slice(resultIndices.front(), resultIndices.size());
 
       if (srcSubShape.size() == resultSubShape.size()) {
-        if (srcSubShape == resultSubShape)
+        if (srcSubShape == resultSubShape &&
+            llvm::count_if(srcSubShape, ShapedType::isDynamic) < 2) {
           composedReassociation.push_back(srcIndices);
-        else
+        } else {
           return std::nullopt;
+        }
       }
 
       // Find reassociation to collapse `srcSubShape` into `resultSubShape`.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index ad313c2d5ce60..579116904aad2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1023,6 +1023,76 @@ struct FoldReshapeWithGenericOpByExpansion
 private:
   ControlFusionFn controlFoldingReshapes;
 };
+
+/// Pattern to bubble up a tensor.expand_shape op through a producer
+/// tensor.collapse_shape op that has non-intersecting reassociations.
+struct BubbleUpExpandThroughParallelCollapse
+    : public OpRewritePattern<tensor::ExpandShapeOp> {
+  using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+                                PatternRewriter &rewriter) const override {
+    auto collapseOp =
+        expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
+    if (!collapseOp || !collapseOp->hasOneUse())
+      return failure();
+    auto expandReInds = expandOp.getReassociationIndices();
+    auto collapseReInds = collapseOp.getReassociationIndices();
+
+    // Reshapes are parallel to each other if none of the reassociation indices
+    // have greater than 1 index for both reshapes.
+    for (auto [expandReassociation, collapseReassociation] :
+         llvm::zip_equal(expandReInds, collapseReInds)) {
+      if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
+        return failure();
+    }
+
+    // Compute new reassociation indices and expanded/collapsed shapes.
+    SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
+    Location loc = expandOp->getLoc();
+    SmallVector<OpFoldResult> collapseSizes =
+        tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
+    SmallVector<OpFoldResult> expandSizes(getMixedValues(
+        expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
+    SmallVector<OpFoldResult> newExpandSizes;
+    int64_t index = 0, expandIndex = 0, collapseIndex = 0;
+    for (auto [idx, collapseReassociation] : llvm::enumerate(collapseReInds)) {
+      if (collapseReassociation.size() != 1) {
+        ReassociationIndices newCollapseReassociation;
+        for (size_t i = 0; i < collapseReassociation.size(); ++i) {
+          newCollapseReassociation.push_back(index);
+          newExpandReInds.push_back({index++});
+          newExpandSizes.push_back(collapseSizes[collapseIndex++]);
+        }
+        newCollapseReInds.push_back(newCollapseReassociation);
+        expandIndex++;
+        continue;
+      }
+      ReassociationIndices newExpandReassociation;
+      auto expandReassociation = expandReInds[idx];
+      for (size_t i = 0; i < expandReassociation.size(); ++i) {
+        newExpandReassociation.push_back(index);
+        newCollapseReInds.push_back({index++});
+        newExpandSizes.push_back(expandSizes[expandIndex++]);
+      }
+      newExpandReInds.push_back(newExpandReassociation);
+      collapseIndex++;
+    }
+
+    // Swap reshape order.
+    SmallVector<Value> dynamicSizes;
+    SmallVector<int64_t> staticSizes;
+    dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
+    auto expandResultType = expandOp.getResultType().clone(staticSizes);
+    auto newExpand = rewriter.create<tensor::ExpandShapeOp>(
+        loc, expandResultType, collapseOp.getSrc(), newExpandReInds,
+        newExpandSizes);
+    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+        expandOp, newExpand.getResult(), newCollapseReInds);
+    return success();
+  }
+};
+
 } // namespace
 
 //===---------------------------------------------------------------------===//
@@ -1939,6 +2009,7 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
                                                     controlFoldingReshapes);
   patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                      controlFoldingReshapes);
+  patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
 }
 
 void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index f42666f81bbad..1354b138983a0 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -826,3 +826,37 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
 // CHECK-SAME:     [0, 1], [2, 3]
 // CHECK-SAME:     tensor<?x7x?x8xf32> into tensor<?x?xf32>
 //      CHECK:   return %[[T4]]
+
+// -----
+
+func.func @bubble_parallel_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
+  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
+  %expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
+              output_shape [%s0, %s1, %s2, %s3] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+  return %expand : tensor<?x?x?x?xf32>
+}
+//      CHECK: func @bubble_parallel_reshapes
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME:   %[[S0:.+]]: index, %[[S1:.+]]: index, %[[S2:.+]]: index, %[[S3:.+]]: index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//  CHECK-DAG:   %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?xf32>
+//  CHECK-DAG:   %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?xf32>
+//      CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2], [3, 4]]
+// CHECK-SAME:       output_shape [%[[S0]], %[[DIM1]], %[[DIM2]], %[[S2]], %[[S3]]] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?xf32>
+//      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EXPAND]] {{\[}}[0], [1, 2], [3], [4]] : tensor<?x?x?x?x?xf32> into tensor<?x?x?x?xf32>
+//      CHECK:   return %[[COLLAPSE]]
+
+// -----
+
+func.func @no_bubble_intersecting_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
+  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
+  %expand = tensor.expand_shape %collapse [[0], [1, 2], [3]]
+              output_shape [%s0, %s1, %s2, %s3] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+  return %expand : tensor<?x?x?x?xf32>
+}
+//      CHECK: func @no_bubble_intersecting_reshapes
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?x?xf32>
+//      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3]]
+//      CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0], [1, 2], [3]]
+//      CHECK:   return %[[EXPAND]]
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index f7fbd3834288b..9a6b03986ccb6 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1139,7 +1139,7 @@ func.func @fold_collapse_of_expand(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32>
   return %1 : tensor<12x4xf32>
 }
 // CHECK-LABEL: @fold_collapse_of_expand
-//   CHECK-NOT:   linalg.{{.*}}shape
+//   CHECK-NOT:   tensor.{{.*}}_shape
 
 // -----
 
@@ -1152,7 +1152,75 @@ func.func @fold_collapse_of_expand_dynamic(%arg0 : tensor<?x?xf32>, %arg1: index
   return %1 : tensor<?x?xf32>
 }
 // CHECK-LABEL: @fold_collapse_of_expand_dynamic
-//   CHECK-NOT:   linalg.{{.*}}_shape
+//   CHECK-NOT:   tensor.{{.*}}_shape
+
+// -----
+
+func.func @fold_collapse_of_expand_fully_dynamic(%arg0 : tensor<?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
+    -> tensor<?x?xf32> {
+  %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [%arg1, %arg2, %arg3]
+      : tensor<?x?xf32> into tensor<?x?x?xf32>
+  %1 = tensor.collapse_shape %0 [[0, 1], [2]]
+      : tensor<?x?x?xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: @fold_collapse_of_expand_fully_dynamic
+//   CHECK-NOT:   tensor.{{.*}}_shape
+
+// -----
+
+func.func @no_fold_parallel_collapse_of_expand_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index, %arg4: index)
+    -> tensor<?x?x?xf32> {
+  %0 = tensor.expand_shape %arg0 [[0, 1], [2], [3]] output_shape [%arg1, %arg2, %arg3, %arg4]
+      : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+  %1 = tensor.collapse_shape %0 [[0], [1], [2, 3]]
+      : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: @no_fold_parallel_collapse_of_expand_dynamic
+//       CHECK:   tensor.expand_shape
+//       CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape
+//       CHECK:   return %[[COLLAPSE]]
+
+// -----
+
+func.func @fold_expand_of_collapse(%arg0 : tensor<3x4x4xf32>) -> tensor<3x4x4xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
+      : tensor<3x4x4xf32> into tensor<12x4xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [3, 4, 4]
+      : tensor<12x4xf32> into tensor<3x4x4xf32>
+  return %1 : tensor<3x4x4xf32>
+}
+// CHECK-LABEL: @fold_expand_of_collapse
+//   CHECK-NOT:   tensor.{{.*}}_shape
+
+// -----
+
+func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
+    -> tensor<?x4x?xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
+      : tensor<?x4x?xf32> into tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
+      : tensor<?x?xf32> into tensor<?x4x?xf32>
+  return %1 : tensor<?x4x?xf32>
+}
+// CHECK-LABEL: @fold_expand_of_collapse_dynamic
+//   CHECK-NOT:   tensor.{{.*}}_shape
+
+// -----
+
+func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
+    -> tensor<?x?x?xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
+      : tensor<?x?x?xf32> into tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, %arg2, %arg3]
+      : tensor<?x?xf32> into tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: @no_fold_expand_of_collapse_dynamic
+//       CHECK:   tensor.collapse_shape
+//       CHECK:   %[[EXPAND:.+]] = tensor.expand_shape
+//       CHECK:   return %[[EXPAND]]
 
 // -----
 

@Max191 (Contributor, Author) commented on Jun 6, 2024

The reverse of this pattern (collapse_shape up through expand_shape) should also be implemented, but I'd rather leave that as another PR later.
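
For reference, a hedged sketch of what that reverse case might look like, with hypothetical static shapes (not part of this patch): here the expand_shape only splits dim 0 and the collapse_shape only merges the trailing dims, so the collapse_shape could be pushed up through the expand_shape:

  %expand = tensor.expand_shape %arg0 [[0, 1], [2], [3]] output_shape [2, 3, 2, 4]
      : tensor<6x2x4xf32> into tensor<2x3x2x4xf32>
  %collapse = tensor.collapse_shape %expand [[0], [1], [2, 3]]
      : tensor<2x3x2x4xf32> into tensor<2x3x8xf32>

would become

  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2]]
      : tensor<6x2x4xf32> into tensor<6x8xf32>
  %expand = tensor.expand_shape %collapse [[0, 1], [2]] output_shape [2, 3, 8]
      : tensor<6x8xf32> into tensor<2x3x8xf32>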

@Max191 force-pushed the parallel-reshape-bubbling branch from 8b5a6be to c96c4ad on June 10, 2024 at 20:39
@Max191 (Contributor, Author) commented on Jun 10, 2024

rebased now

@MaheshRavishankar (Contributor) left a comment:

OK, I think I understand the code and this makes sense. Can you add a test for partially intersecting reshapes as well?
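
For illustration, one possible shape such a test might take (hypothetical, not part of this patch): the first reassociation group is non-trivial for both reshapes (intersecting) while the second is non-trivial only for the collapse, so the pattern should leave the IR unchanged:

  func.func @no_bubble_partially_intersecting_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index) -> tensor<?x?x?xf32> {
    %collapse = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] : tensor<?x?x?x?xf32> into tensor<?x?xf32>
    %expand = tensor.expand_shape %collapse [[0, 1], [2]]
                output_shape [%s0, %s1, %s2] : tensor<?x?xf32> into tensor<?x?x?xf32>
    return %expand : tensor<?x?x?xf32>
  }
  // CHECK: func @no_bubble_partially_intersecting_reshapes
  // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape
  // CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]]
  // CHECK: return %[[EXPAND]]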


// Reshapes are parallel to each other if none of the reassociation indices
// have greater than 1 index for both reshapes.
for (auto [expandReassociation, collapseReassociation] :
A reviewer (Contributor) commented on the snippet above:

Such reshapes should just be folded away.

@Max191 (Contributor, Author) replied:

This comment means that there are no reassociations where the size is greater than 1 for both the expand and the collapse at the same time. There could be cases where only one of the collapse or expand shapes has size > 1, which would be parallel reshapes, but not identity reshapes. I can update the comment to be clearer.
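
For illustration, a parallel-but-not-identity pair with hypothetical static shapes: the collapse merges dims 1 and 2 while the expand splits the last dim of the collapsed tensor, so neither op undoes the other and the pair cannot simply be folded away, but the reassociations do not intersect and the ops can be reordered:

  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]]
      : tensor<2x3x4x8xf32> into tensor<2x12x8xf32>
  %expand = tensor.expand_shape %collapse [[0], [1], [2, 3]] output_shape [2, 12, 2, 4]
      : tensor<2x12x8xf32> into tensor<2x12x2x4xf32>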

@MaheshRavishankar (Contributor) left a comment:

These patterns don't have anything to do with Linalg ops. Can we move this to the Tensor dialect? We probably need a populate* method there that you can include in this file.

MaheshRavishankar pushed a commit that referenced this pull request Aug 14, 2024
Refactored @Max191's PR #94637
to move it to `Tensor`

From the original PR
>This PR adds fusion by expansion patterns to push a tensor.expand_shape
up through a tensor.collapse_shape with non-intersecting reassociations.
Sometimes parallel collapse_shape ops like this can block propagation of
expand_shape ops, so this allows them to pass through each other.

I'm not sure if I put the code/tests in the right places, so let me know
where those go if they aren't.

cc @MaheshRavishankar @hanhanW

---------

Co-authored-by: Max Dawkins <max.dawkins@gmail.com>
bwendling pushed a commit to bwendling/llvm-project that referenced this pull request Aug 15, 2024
@IanWood1 (Contributor) commented:

#103401

@IanWood1 closed this on Aug 15, 2024